# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any
from typing import Optional
from typing import Union

from google.genai import types as genai_types
from pydantic import Field
from pydantic import model_validator
from typing_extensions import TypeAlias

from .app_details import AppDetails
from .common import EvalBaseModel
from .conversation_scenarios import ConversationScenario
from .eval_rubrics import Rubric


class IntermediateData(EvalBaseModel):
  """Container for intermediate data that an agent would generate as it responds with a final answer.

  Tool calls, tool responses and intermediate responses are kept in three
  separate lists, each of which defaults to empty.
  """

  # Function calls recorded for the invocation.
  tool_uses: list[genai_types.FunctionCall] = []
  """Tool use trajectory in chronological order."""

  # Function responses recorded for the invocation. This class does not link a
  # response to the call that produced it.
  tool_responses: list[genai_types.FunctionResponse] = []
  """Tool response trajectory in chronological order."""

  # Each entry is an (author, parts) pair; see the attribute docstring below.
  intermediate_responses: list[tuple[str, list[genai_types.Part]]] = []
  """Intermediate responses generated by sub-agents to convey progress or status
  in a multi-agent system, distinct from the final response.

  This is expressed as a tuple of:
    - Author: Usually the sub-agent name that generated the intermediate
      response.

    - A list of Parts that comprise of the response.
  """


class InvocationEvent(EvalBaseModel):
  """An immutable record representing a specific point in the agent's invocation.

  It captures the agent's replies, requests to use tools (function calls), and
  tool results.

  This structure is a simple projection of the actual `Event` datamodel that
  is intended for the Eval System.
  """

  # NOTE(review): nothing in this class enforces the immutability mentioned
  # above; if it is enforced, that would have to come from EvalBaseModel's
  # configuration - confirm.

  # Required: no default is declared.
  author: str
  """The name of the agent that authored/owned this event."""

  # Declared without a default, so under pydantic v2 the field must be supplied
  # explicitly, although None is an accepted value.
  content: Optional[genai_types.Content]
  """The content of the event."""


class InvocationEvents(EvalBaseModel):
  """A container for events that occur during the course of an invocation.

  This is one of the two shapes accepted wherever `IntermediateDataType` is
  expected.
  """

  # Defaults to a fresh empty list for every instance.
  invocation_events: list[InvocationEvent] = Field(default_factory=list)
  """A list of invocation events."""


# Either representation of an invocation's intermediate steps: the structured
# `IntermediateData` container or an event list wrapped in `InvocationEvents`.
IntermediateDataType: TypeAlias = Union[IntermediateData, InvocationEvents]


class Invocation(EvalBaseModel):
  """Represents a single invocation.

  An invocation bundles the user's content with the agent's final response
  and, optionally, the intermediate data produced along the way.
  `user_content` is the only field declared without a default.
  """

  # Defaults to the empty string when not supplied.
  invocation_id: str = ""
  """Unique identifier for the invocation."""

  # Required: no default is declared.
  user_content: genai_types.Content
  """Content provided by the user in this invocation."""

  # None when no final response was recorded.
  final_response: Optional[genai_types.Content] = None
  """Final response from the agent."""

  # Accepts either an `IntermediateData` or an `InvocationEvents` instance.
  intermediate_data: Optional[IntermediateDataType] = None
  """Intermediate steps generated as a part of Agent execution.

  For a multi-agent system, it is also helpful to inspect the route that
  the agent took to generate final response.
  """

  # NOTE(review): the unit/epoch of this float is not stated in this file -
  # confirm with the code that populates it.
  creation_timestamp: float = 0.0
  """Timestamp for the current invocation, primarily intended for debugging purposes."""

  # None (the default) means no invocation-specific rubrics were supplied.
  rubrics: Optional[list[Rubric]] = Field(
      default=None,
  )
  """A list of rubrics that are applicable to only this invocation."""

  # None (the default) means no app details were recorded.
  app_details: Optional[AppDetails] = Field(default=None)
  """Details about the App that was used for this invocation."""


# A mapping from string keys to values of any type.
SessionState: TypeAlias = dict[str, Any]
"""The state of the session."""


class SessionInput(EvalBaseModel):
  """Values that help initialize a Session.

  `app_name` and `user_id` are declared without defaults; `state` falls back
  to an empty dict.
  """

  # Required: no default is declared.
  app_name: str
  """The name of the app."""

  # Required: no default is declared.
  user_id: str
  """The user id."""

  # A fresh empty dict is created per instance when no state is supplied.
  state: SessionState = Field(default_factory=dict)
  """The state of the session."""


# A plain list of `Invocation` objects.
StaticConversation: TypeAlias = list[Invocation]
"""A conversation where the user's queries for each invocation are already specified."""


class EvalCase(EvalBaseModel):
  """A single evaluation case.

  An eval case is driven by exactly one of `conversation` or
  `conversation_scenario`; the model validator at the bottom of the class
  rejects instances that set both or neither.
  """

  # Required: no default is declared.
  eval_id: str
  """Unique identifier for the evaluation case."""

  conversation: Optional[StaticConversation] = None
  """A static conversation between the user and the Agent.

   While creating an eval case you should specify either a `conversation` or a
  `conversation_scenario`, but not both.
  """

  conversation_scenario: Optional[ConversationScenario] = None
  """A conversation scenario that should be used by a UserSimulator.

  While creating an eval case you should specify either a `conversation` or a
  `conversation_scenario`, but not both.
  """

  session_input: Optional[SessionInput] = None
  """Session input that will be passed on to the Agent during eval.
     It is common for Agents state to be initialized to some initial/default value,
     for example, your agent may need to know today's date.
  """

  creation_timestamp: float = 0.0
  """The time at which this eval case was created."""

  rubrics: Optional[list[Rubric]] = Field(
      default=None,
  )
  """A list of rubrics that are applicable to all the invocations in the conversation of this eval case."""

  # Typed as Optional, but the default is a fresh empty dict rather than None.
  final_session_state: Optional[SessionState] = Field(default_factory=dict)
  """The expected final session state at the end of the conversation."""

  @model_validator(mode="after")
  def ensure_conversation_xor_conversation_scenario(self) -> EvalCase:
    """Checks that exactly one conversation source is set.

    Returns:
      The instance itself, unchanged, when the check passes.

    Raises:
      ValueError: If `conversation` and `conversation_scenario` are both set
        or both None.
    """
    has_conversation = self.conversation is not None
    has_scenario = self.conversation_scenario is not None

    # The two flags differ only when exactly one of the fields is set.
    if has_conversation != has_scenario:
      return self

    raise ValueError(
        "Exactly one of conversation and conversation_scenario must be"
        " provided in an EvalCase."
    )


def get_all_tool_calls(
    intermediate_data: Optional[IntermediateDataType],
) -> list[genai_types.FunctionCall]:
  """Collects every function call recorded in `intermediate_data`.

  Args:
    intermediate_data: Either an `IntermediateData`, an `InvocationEvents`, or
      None.

  Returns:
    An empty list for falsy input. For `IntermediateData`, its `tool_uses`
    list itself (not a copy). For `InvocationEvents`, a new list of the
    function calls found in the parts of each event's content, in event order.

  Raises:
    ValueError: If `intermediate_data` is truthy but is neither of the two
      supported types.
  """
  if not intermediate_data:
    return []

  if isinstance(intermediate_data, IntermediateData):
    return intermediate_data.tool_uses

  if not isinstance(intermediate_data, InvocationEvents):
    raise ValueError(
        f"Unsupported type for intermediate_data `{intermediate_data}`"
    )

  collected = []
  for event in intermediate_data.invocation_events:
    content = event.content
    # Events with no content, or content with no parts, contribute nothing.
    if not (content and content.parts):
      continue
    collected.extend(
        part.function_call for part in content.parts if part.function_call
    )
  return collected


def get_all_tool_responses(
    intermediate_data: Optional[IntermediateDataType],
) -> list[genai_types.FunctionResponse]:
  """Collects every function response recorded in `intermediate_data`.

  Args:
    intermediate_data: Either an `IntermediateData`, an `InvocationEvents`, or
      None.

  Returns:
    An empty list for falsy input. For `IntermediateData`, its
    `tool_responses` list itself (not a copy). For `InvocationEvents`, a new
    list of the function responses found in the parts of each event's content,
    in event order.

  Raises:
    ValueError: If `intermediate_data` is truthy but is neither of the two
      supported types.
  """
  if not intermediate_data:
    return []

  if isinstance(intermediate_data, IntermediateData):
    return intermediate_data.tool_responses

  if not isinstance(intermediate_data, InvocationEvents):
    raise ValueError(
        f"Unsupported type for intermediate_data `{intermediate_data}`"
    )

  found = []
  for event in intermediate_data.invocation_events:
    content = event.content
    # Events with no content, or content with no parts, contribute nothing.
    if not (content and content.parts):
      continue
    found.extend(
        part.function_response
        for part in content.parts
        if part.function_response
    )
  return found


# Pair of (function call, function response); the second element is None when
# no response accompanies the call.
ToolCallAndResponse: TypeAlias = tuple[
    genai_types.FunctionCall, Optional[genai_types.FunctionResponse]
]
"""A Tuple representing a Function call and corresponding optional function response."""


def get_all_tool_calls_with_responses(
    intermediate_data: Optional[IntermediateDataType],
) -> list[ToolCallAndResponse]:
  """Returns tool calls with the corresponding responses, if available.

  A response corresponds to a call when both carry the same non-None `id`.
  Calls or responses without an id are never matched, because without an id
  there is nothing to establish which response belongs to which call.

  Args:
    intermediate_data: Either an `IntermediateData`, an `InvocationEvents`, or
      None.

  Returns:
    One `(tool_call, response)` tuple per tool call, in the order the calls
    are returned by `get_all_tool_calls`. `response` is None when no response
    with a matching id exists.

  Raises:
    ValueError: Propagated from `get_all_tool_calls` /
      `get_all_tool_responses` when `intermediate_data` is of an unsupported
      type.
  """
  # Index responses by id. Responses without an id are left out: keying them
  # on None would let the last id-less response overwrite the others and then
  # be attached to every id-less call.
  responses_by_call_id: dict[str, genai_types.FunctionResponse] = {
      tool_response.id: tool_response
      for tool_response in get_all_tool_responses(intermediate_data)
      if tool_response.id is not None
  }

  # None is never a key in the index, so a call without an id resolves to a
  # None response here.
  return [
      (tool_call, responses_by_call_id.get(tool_call.id))
      for tool_call in get_all_tool_calls(intermediate_data)
  ]
